{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wC0PtNm3Sa_T"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Hub Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hgOqPjRKSa-7"
      },
      "outputs": [],
      "source": [
        "# Copyright 2019 The TensorFlow Hub Authors. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# =============================================================================="
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oKAkxAYuONU6"
      },
      "source": [
        "# Video Inbetweening using 3D Convolutions\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/hub/tutorials/tweening_conv3d\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/hub/tutorials/tweening_conv3d.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/hub/tutorials/tweening_conv3d.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/hub/tutorials/tweening_conv3d.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://tfhub.dev/google/tweening_conv3d_bair/1\"><img src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" />See TF Hub model</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cvMgkVIBpT-Y"
      },
      "source": [
        "Yunpeng Li, Dominik Roblek, and Marco Tagliasacchi. From Here to There: Video Inbetweening Using Direct 3D Convolutions, 2019.\n",
        "\n",
        "https://arxiv.org/abs/1905.10240\n",
        "\n",
        "\n",
        "Current Hub characteristics:\n",
        "- has models for BAIR Robot pushing videos and KTH action video dataset (though this colab uses only BAIR)\n",
        "- BAIR dataset already available in Hub. However, KTH videos need to be supplied by the users themselves.\n",
        "- only evaluation (video generation) for now\n",
        "- batch size and frame size are hard-coded\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q4DN769E2O_R"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EsQFWvxrYrHg"
      },
      "source": [
        "Since `tfds.load('bair_robot_pushing_small', split='test')` would download a 30GB archive that also contains the training data, we download a separated archive that only contains the 190MB test data. The used dataset has been published by [this paper](https://arxiv.org/abs/1710.05268) and is licensed as Creative Commons BY 4.0."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GhIKakhc7JYL"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import seaborn as sns\n",
        "import tensorflow_hub as hub\n",
        "import tensorflow_datasets as tfds\n",
        "\n",
        "from tensorflow_datasets.core import SplitGenerator\n",
        "from tensorflow_datasets.video.bair_robot_pushing import BairRobotPushingSmall\n",
        "\n",
        "import tempfile\n",
        "import pathlib\n",
        "\n",
        "TEST_DIR = pathlib.Path(tempfile.mkdtemp()) / \"bair_robot_pushing_small/softmotion30_44k/test/\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zBMz14GmYkwz"
      },
      "outputs": [],
      "source": [
        "# Download the test split to $TEST_DIR\n",
        "!mkdir -p $TEST_DIR\n",
        "!wget -nv https://storage.googleapis.com/download.tensorflow.org/data/bair_test_traj_0_to_255.tfrecords -O $TEST_DIR/traj_0_to_255.tfrecords"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "irRJ2Q0iYoW0"
      },
      "outputs": [],
      "source": [
        "# Since the dataset builder expects the train and test split to be downloaded,\n",
        "# patch it so it only expects the test data to be available\n",
        "builder = BairRobotPushingSmall()\n",
        "test_generator = SplitGenerator(name='test', gen_kwargs={\"filedir\": str(TEST_DIR)})\n",
        "builder._split_generators = lambda _: [test_generator]\n",
        "builder.download_and_prepare()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iaGU8hhBPi_6"
      },
      "source": [
        "## BAIR: Demo based on numpy array inputs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "IgWmW8YzEiDo"
      },
      "outputs": [],
      "source": [
        "# @title Load some example data (BAIR).\n",
        "batch_size = 16\n",
        "\n",
        "# If unable to download the dataset automatically due to \"not enough disk space\", please download manually to Google Drive and\n",
        "# load using tf.data.TFRecordDataset.\n",
        "ds = builder.as_dataset(split=\"test\")\n",
        "test_videos = ds.batch(batch_size)\n",
        "first_batch = next(iter(test_videos))\n",
        "input_frames = first_batch['image_aux1'][:, ::15]\n",
        "input_frames = tf.cast(input_frames, tf.float32)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "96Jd5XefGHRr"
      },
      "outputs": [],
      "source": [
        "# @title Visualize loaded videos start and end frames.\n",
        "\n",
        "print('Test videos shape [batch_size, start/end frame, height, width, num_channels]: ', input_frames.shape)\n",
        "sns.set_style('white')\n",
        "plt.figure(figsize=(4, 2*batch_size))\n",
        "\n",
        "for i in range(batch_size)[:4]:\n",
        "  plt.subplot(batch_size, 2, 1 + 2*i)\n",
        "  plt.imshow(input_frames[i, 0] / 255.0)\n",
        "  plt.title('Video {}: First frame'.format(i))\n",
        "  plt.axis('off')\n",
        "  plt.subplot(batch_size, 2, 2 + 2*i)\n",
        "  plt.imshow(input_frames[i, 1] / 255.0)\n",
        "  plt.title('Video {}: Last frame'.format(i))\n",
        "  plt.axis('off')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w0FFhkikQABy"
      },
      "source": [
        "### Load Hub Module"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cLAUiWfEQAB5"
      },
      "outputs": [],
      "source": [
        "hub_handle = 'https://tfhub.dev/google/tweening_conv3d_bair/1'\n",
        "module = hub.load(hub_handle).signatures['default']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PVHTdXnhbGsK"
      },
      "source": [
        "### Generate and show the videos"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FHAwBW-zyegP"
      },
      "outputs": [],
      "source": [
        "filled_frames = module(input_frames)['default'] / 255.0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tVesWHTnSW1Z"
      },
      "outputs": [],
      "source": [
        "# Show sequences of generated video frames.\n",
        "\n",
        "# Concatenate start/end frames and the generated filled frames for the new videos.\n",
        "generated_videos = np.concatenate([input_frames[:, :1] / 255.0, filled_frames, input_frames[:, 1:] / 255.0], axis=1)\n",
        "\n",
        "for video_id in range(4):\n",
        "  fig = plt.figure(figsize=(10 * 2, 2))\n",
        "  for frame_id in range(1, 16):\n",
        "    ax = fig.add_axes([frame_id * 1 / 16., 0, (frame_id + 1) * 1 / 16., 1],\n",
        "                      xmargin=0, ymargin=0)\n",
        "    ax.imshow(generated_videos[video_id, frame_id])\n",
        "    ax.axis('off')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "Q4DN769E2O_R"
      ],
      "name": "tweening_conv3d.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
